import os
import torch
from torch.utils.data.dataloader import DataLoader

# Dataset names; each one is substituted into the per-dataset file names
# used further down in this script.
datasets = [
    'alamy',
    'featurepics',
    'freepik',
    'istockphoto',
]
# One entry per dataset, aligned by position with `datasets`. Only the
# commented-out merge step below reads it, as the number of
# word_freq_NNN.pth shards to combine for that dataset.
split_nums = [
    10,
    3,
    11,
    12,
]

# indir = '/raid/data/dochen/top-host'
# for split_num, ds in zip(split_nums, datasets):
#     word_freq_dict = None
#     for idx in range(split_num):
#         curr_dict = torch.load(os.path.join(indir, '%s_illustration'%ds, 'word_freq_%03d.pth' % idx))
#         if idx == 0:
#             word_freq_dict = curr_dict
#         else:
#             for k,v in curr_dict.items():
#                 if k in word_freq_dict:
#                     word_freq_dict[k] += v
#                 else:
#                     word_freq_dict[k] = v
    
#     torch.save(word_freq_dict, os.path.join(indir, '%s_illustration'%ds, '%s_word_freq_combined.pth' % ds))

# Rank the combined word frequencies of each dataset and write them to a text file
# indir = '/mnt2/datasets/tophost-art'
# for ds in datasets:
#     word_freq_dict = torch.load(os.path.join(indir, '%s_word_freq_combined.pth' % ds))
#     word_freq_list = []
#     for k, v in word_freq_dict.items():
#         word_freq_list.append((k, v))
#     word_freq_list = sorted(word_freq_list, key=lambda x: -x[1])

#     total_freq = sum([x[1] for x in word_freq_list])
#     print('total len: %d, total freq: %d' % (len(word_freq_list), total_freq))

#     with open(os.path.join(indir, '%s_word_freq_rank.txt' % ds), 'w') as fid:
#         for wf in word_freq_list:
#             fid.write('%s, %d, %f\n' % (wf[0], wf[1], wf[1]/float(total_freq)))

from scripts.tools.filter_illustration_data import TSVTextDataset

# For every dataset: scan its captions, then load the pre-computed combined
# word-frequency dict, rank the words by descending frequency and write one
# "<word>, <count>, <relative frequency>" line per word to a per-dataset file.
indir = '/mnt2/datasets/tophost-art'
for ds in datasets:
    # NOTE(review): `args` and `text_tsv_list` are not defined anywhere in the
    # visible file, so this line raises NameError as written -- confirm where
    # they are supposed to come from (argparse? a per-dataset TSV list?).
    tsv_dataset = TSVTextDataset(args.dataset_name, text_tsv_list, data_root=args.data_root, text_format='json')
    tsv_loader = DataLoader(tsv_dataset, batch_size=1, shuffle=False, num_workers=2)

    print('searching good text idxs')
    for idx, text in enumerate(tsv_loader):
        if idx % 10000 == 0:
            print('%d/%d' % (idx, len(tsv_loader)))
        # batch_size=1, so index 0 is the only sample in the batch.
        # A set is used so each word counts at most once per caption.
        words = set(text[0].lower().split())
        for word in words:
            # NOTE(review): `word_freq` is never initialised in the visible
            # file and its counts are never read afterwards; the ranking
            # below uses the dict loaded from disk instead. Confirm intent.
            word_freq[word] += 1

    word_freq_dict = torch.load(os.path.join(indir, '%s_word_freq_combined.pth' % ds))
    # Most frequent words first; the sort is stable, so ties keep dict order.
    word_freq_list = sorted(word_freq_dict.items(), key=lambda kv: -kv[1])

    total_freq = sum(freq for _, freq in word_freq_list)
    print('total len: %d, total freq: %d' % (len(word_freq_list), total_freq))

    # Fix: the file name template was never filled in, so every dataset wrote
    # to the same literal '%s_word_freq_rank.txt' and overwrote the previous one.
    with open(os.path.join(indir, '%s_word_freq_rank.txt' % ds), 'w') as fid:
        for word, freq in word_freq_list:
            fid.write('%s, %d, %f\n' % (word, freq, freq / float(total_freq)))